PI-MAE with a 75% Noise Mask

In [1]:
!pip install tensorflow_addons
In [2]:
from tensorflow.keras import layers
import tensorflow_addons as tfa
from tensorflow import keras
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np
import random

SEED = 42
keras.utils.set_random_seed(SEED)

from keras import backend
backend.set_image_data_format('channels_last')  # use 'channels_first' for NCHW

# Eager execution is required because the masking logic below calls .numpy()
# on tensors inside layer call()s.
tf.config.run_functions_eagerly(True)
In [3]:
import os
os.getcwd()
Out[3]:
'/home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict'
In [4]:
BUFFER_SIZE = 1024
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
LEARNING_RATE = 5e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 50
IMAGE_SIZE = 48
PATCH_SIZE = 3
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
MASK_PROPORTION = 0.75
LAYER_NORM_EPS = 1e-6
ENC_PROJECTION_DIM = 128
DEC_PROJECTION_DIM = 64
ENC_NUM_HEADS = 4
ENC_LAYERS = 6
DEC_NUM_HEADS = 4
DEC_LAYERS = 2
ENC_TRANSFORMER_UNITS = [
    ENC_PROJECTION_DIM * 2,
    ENC_PROJECTION_DIM,
]
DEC_TRANSFORMER_UNITS = [
    DEC_PROJECTION_DIM * 2,
    DEC_PROJECTION_DIM,
]
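As a quick sanity check of the hyperparameters above (plain Python, values copied from the cell), the patch grid and mask sizes work out as follows:

```python
# Sanity-check the patch arithmetic implied by the hyperparameters above.
IMAGE_SIZE = 48
PATCH_SIZE = 3
MASK_PROPORTION = 0.75

patches_per_side = IMAGE_SIZE // PATCH_SIZE      # 16 patches along each axis
num_patches = patches_per_side ** 2              # 256 patches per image
num_mask = int(MASK_PROPORTION * num_patches)    # 192 patches hidden from the encoder
num_visible = num_patches - num_mask             # 64 patches the encoder actually sees

print(patches_per_side, num_patches, num_mask, num_visible)  # 16 256 192 64
```

The 64 visible patches are what the intensity-based masking below keeps as well, so the two masking strategies produce equally sized encoder inputs.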
In [5]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)

def pad_to_32(images):
    # Zero-pad each 28x28 digit by 2 pixels per side to make it 32x32.
    return np.pad(images, pad_width=((0, 0), (2, 2), (2, 2)), mode='constant')

x_train = pad_to_32(x_train)
x_test = pad_to_32(x_test)
x_val = pad_to_32(x_val)

x_train = np.expand_dims(x_train, -1)
y_train = np.expand_dims(y_train, -1)

x_test = np.expand_dims(x_test, -1)
y_test = np.expand_dims(y_test, -1)

x_val = np.expand_dims(x_val, -1)
y_val = np.expand_dims(y_val, -1)

x_train = tf.image.grayscale_to_rgb(tf.convert_to_tensor(x_train))
x_test = tf.image.grayscale_to_rgb(tf.convert_to_tensor(x_test))
x_val = tf.image.grayscale_to_rgb(tf.convert_to_tensor(x_val))

train_ds = tf.data.Dataset.from_tensor_slices(x_train)
train_ds = train_ds.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTO)

val_ds = tf.data.Dataset.from_tensor_slices(x_val)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)

test_ds = tf.data.Dataset.from_tensor_slices(x_test)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
2024-03-29 01:29:02.961689: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13775 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:1e.0, compute capability: 7.5
In [6]:
def get_train_augmentation_model():
    model = keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
            layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
            layers.RandomFlip("horizontal"),
        ],
        name="train_data_augmentation",
    )
    return model


def get_test_augmentation_model():
    model = keras.Sequential(
        [
            layers.Rescaling(1 / 255.0),
            layers.Resizing(IMAGE_SIZE, IMAGE_SIZE, interpolation='area'),
        ],
        name="test_data_augmentation",
    )
    return model
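The training pipeline upsizes each image to 52x52 and then takes a random 48x48 crop. A minimal NumPy sketch of that crop step (a hypothetical `random_crop` helper, not the Keras `RandomCrop` layer itself):

```python
import numpy as np

def random_crop(image, crop_size, rng=None):
    # Pick a random top-left corner so the crop window stays inside the image.
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    top = int(rng.integers(0, h - crop_size + 1))
    left = int(rng.integers(0, w - crop_size + 1))
    return image[top:top + crop_size, left:left + crop_size]

img = np.zeros((52, 52, 3))            # upsized image, as in the pipeline above
crop = random_crop(img, crop_size=48)  # same output size as RandomCrop(IMAGE_SIZE, IMAGE_SIZE)
print(crop.shape)  # (48, 48, 3)
```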
In [7]:
class Patches(layers.Layer):
    def __init__(self, patch_size=PATCH_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.resize = layers.Reshape((-1, patch_size * patch_size * 3))

    def call(self, images):
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )

        patches = self.resize(patches)

        return patches

    def show_patched_image(self, images, patches):
        idx = np.random.choice(patches.shape[0])
        plt.figure(figsize=(4, 4))
        plt.imshow(keras.utils.array_to_img(images[idx]))
        plt.axis("off")
        plt.show()

        n = int(np.sqrt(patches.shape[1]))
        plt.figure(figsize=(4, 4))
        for i, patch in enumerate(patches[idx]):
            ax = plt.subplot(n, n, i + 1)
            patch_img = tf.reshape(patch, (self.patch_size, self.patch_size, 3))
            plt.imshow(keras.utils.img_to_array(patch_img))
            plt.axis("off")
        plt.show()

        return idx

    def reconstruct_from_patch(self, patch):
        num_patches = patch.shape[0]
        n = int(np.sqrt(num_patches))
        patch = tf.reshape(patch, (num_patches, self.patch_size, self.patch_size, 3))
        rows = tf.split(patch, n, axis=0)
        rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]
        reconstructed = tf.concat(rows, axis=0)

        plt.title('reconstructed')
        plt.imshow(reconstructed, cmap='gray')
        plt.show()

        return reconstructed

    def reconstruct_from_patch_and_find_dark_spots_batch(self, patch):
        # For each image in the batch, rank patches by total intensity:
        # the 64 brightest stay unmasked, the darker remainder is masked.
        unmask, mask = [], []

        for idx in range(patch.shape[0]):
            num_patches = patch[idx].shape[0]
            p = tf.reshape(patch[idx], (num_patches, self.patch_size, self.patch_size, 3))

            sum_patches = [np.sum(elem.numpy()) for elem in p]

            # Indices sorted from brightest to darkest patch.
            sorted_indices = np.flip(np.argsort(sum_patches))

            mask_amount = 64
            unmask.append(sorted_indices[:mask_amount])
            mask.append(sorted_indices[mask_amount:])

        return np.array(unmask), np.array(mask)

    def reconstruct_from_patch_and_find_dark_spots(self, patch):
        # Rank this image's patches by total intensity: the 64 brightest
        # stay unmasked, the darker remainder is masked.
        num_patches = patch.shape[0]
        patch = tf.reshape(patch, (num_patches, self.patch_size, self.patch_size, 3))

        sum_patches = [np.sum(elem.numpy()) for elem in patch]

        # Indices sorted from brightest to darkest patch.
        sorted_indices = np.flip(np.argsort(sum_patches))

        mask_amount = 64
        unmask = sorted_indices[:mask_amount]
        mask = sorted_indices[mask_amount:]

        return unmask, mask


    def reconstruct_from_patch_and_find_dark_spots_256(self, patch):

        total_mask = []
        total_unmask = []

        for picture_idx in range(patch.shape[0]):
            unmask, mask = self.reconstruct_from_patch_and_find_dark_spots(patch[picture_idx])
            total_unmask.append(unmask)
            total_mask.append(mask)

        return total_unmask, total_mask
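The patchify/reconstruct pair above can be checked with a small NumPy round trip: split an image into non-overlapping patches the same way `tf.image.extract_patches` does with stride equal to patch size, then stitch them back. This is a sketch using the shapes from this notebook (48x48x3, 3x3 patches), with hypothetical `patchify`/`unpatchify` helpers:

```python
import numpy as np

def patchify(image, p):
    # (H, W, C) -> (num_patches, p*p*C) in row-major patch order,
    # matching extract_patches with stride == patch size.
    h, w, c = image.shape
    x = image.reshape(h // p, p, w // p, p, c)
    x = x.transpose(0, 2, 1, 3, 4)               # (rows, cols, p, p, c)
    return x.reshape(-1, p * p * c)

def unpatchify(patches, p, h, w, c=3):
    # Inverse of patchify: stitch flattened patches back into an image.
    x = patches.reshape(h // p, w // p, p, p, c)
    x = x.transpose(0, 2, 1, 3, 4)               # (rows, p, cols, p, c)
    return x.reshape(h, w, c)

img = np.arange(48 * 48 * 3, dtype=np.float32).reshape(48, 48, 3)
patches = patchify(img, 3)
print(patches.shape)                                        # (256, 27)
print(np.array_equal(unpatchify(patches, 3, 48, 48), img))  # True
```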
In [8]:
image_batch = next(iter(train_ds))
augmentation_model = get_train_augmentation_model()
augmented_images = augmentation_model(image_batch)

patch_layer = Patches()

patches = patch_layer(images=augmented_images)

random_index = patch_layer.show_patched_image(images=augmented_images, patches=patches)

image = patch_layer.reconstruct_from_patch(patches[random_index])
plt.imshow(image)
plt.axis("off")
plt.show()
In [9]:
class PatchEncoder(layers.Layer):
    def __init__(
        self,
        patch_size=PATCH_SIZE,
        projection_dim=ENC_PROJECTION_DIM,
        mask_proportion=MASK_PROPORTION,
        downstream=False,
        alreadyMasked=False,
        setMasking=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.mask_proportion = mask_proportion
        self.downstream = downstream
        self.alreadyMasked = alreadyMasked
        self.setMasking = setMasking
        self.mask_indices = 0
        self.unmask_indices = 0
        self.patches = 0
        self.tmp_x = 0

        self.mask_token = tf.Variable(
            tf.random.normal([1, patch_size * patch_size * 3]), trainable=True
        )

    def build(self, input_shape):
        (_, self.num_patches, self.patch_area) = input_shape

        self.projection = layers.Dense(units=self.projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches, output_dim=self.projection_dim
        )
        self.num_mask = int(self.mask_proportion * self.num_patches)

    def call(self, patches):
        batch_size = tf.shape(patches)[0]
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        pos_embeddings = self.position_embedding(positions[tf.newaxis, ...])
        pos_embeddings = tf.tile(
            pos_embeddings, [batch_size, 1, 1]
        )
        patch_embeddings = (
            self.projection(patches) + pos_embeddings
        )

        if self.downstream:
            return patch_embeddings
        elif self.alreadyMasked:
            if self.setMasking == False:
                mask_indices, unmask_indices = self.get_already_masked_indices(patch_embeddings, patches, batch_size)
            else:
                tmp = self.tmp_x[::2, ::2].flatten()
                mask_indices = np.array([np.where(tmp == 255)[0]])
                unmask_indices = np.array([np.where(tmp != 255)[0]])

                if patch_embeddings.shape[0] != 1:
                    mask_indices = np.tile(mask_indices, (patch_embeddings.shape[0], 1))
                    unmask_indices = np.tile(unmask_indices, (patch_embeddings.shape[0], 1))

            unmasked_embeddings = tf.gather(
                patch_embeddings, unmask_indices, axis=1, batch_dims=1
            )
            unmasked_positions = tf.gather(
                pos_embeddings, unmask_indices, axis=1, batch_dims=1
            )
            masked_positions = tf.gather(
                pos_embeddings, mask_indices, axis=1, batch_dims=1
            )
            mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
            mask_tokens = tf.repeat(
                mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
            )
            masked_embeddings = self.projection(mask_tokens) + masked_positions

            return (
                unmasked_embeddings,  # Input to the encoder.
                masked_embeddings,  # First part of input to the decoder.
                unmasked_positions,  # Added to the encoder outputs.
                mask_indices,  # The indices that were masked.
                unmask_indices,  # The indices that were unmasked.
            )

        else:
            mask_indices, unmask_indices = self.get_random_indices(batch_size)
            unmasked_embeddings = tf.gather(
                patch_embeddings, unmask_indices, axis=1, batch_dims=1
            )
            unmasked_positions = tf.gather(
                pos_embeddings, unmask_indices, axis=1, batch_dims=1
            )
            masked_positions = tf.gather(
                pos_embeddings, mask_indices, axis=1, batch_dims=1
            )
            mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
            mask_tokens = tf.repeat(
                mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
            )
            masked_embeddings = self.projection(mask_tokens) + masked_positions
            return (
                unmasked_embeddings,  # Input to the encoder.
                masked_embeddings,  # First part of input to the decoder.
                unmasked_positions,  # Added to the encoder outputs.
                mask_indices,  # The indices that were masked.
                unmask_indices,  # The indices that were unmasked.
            )

    def get_already_masked_indices(self, patch_embeddings, patches, batch_size):
        plt.title('in get_already_masked_indices')
        plt.imshow(patches.numpy()[0])
        plt.show()

        # Note: this relies on the module-level `patch_layer` instance rather
        # than a layer owned by this encoder.
        unmask_indices, mask_indices = patch_layer.reconstruct_from_patch_and_find_dark_spots_batch(patches.numpy())
        return np.array(mask_indices), np.array(unmask_indices)


    def get_random_indices(self, batch_size):
        uni = tf.random.uniform(shape=(batch_size, self.num_patches))
        rand_indices = tf.argsort(uni, axis=-1)
        mask_indices = rand_indices[:, : self.num_mask]
        unmask_indices = rand_indices[:, self.num_mask:]
        return mask_indices, unmask_indices

    def generate_masked_image(self, patches, unmask_indices):
        idx = np.random.choice(patches.shape[0])
        patch = patches[idx]
        unmask_index = unmask_indices[idx]
        new_patch = np.zeros_like(patch)
        for i in range(unmask_index.shape[0]):
            new_patch[unmask_index[i]] = patch[unmask_index[i]]
        return new_patch, idx
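`get_random_indices` is an argsort over i.i.d. uniform noise, which yields an independent random permutation of patch indices per image. The same trick in NumPy (a sketch with the batch size chosen for illustration; patch counts match this notebook):

```python
import numpy as np

def get_random_indices(batch_size, num_patches, num_mask, rng=None):
    # Argsorting uniform noise gives a random permutation per row;
    # the first num_mask entries are masked, the rest stay visible.
    rng = np.random.default_rng(0) if rng is None else rng
    rand_indices = np.argsort(rng.uniform(size=(batch_size, num_patches)), axis=-1)
    return rand_indices[:, :num_mask], rand_indices[:, num_mask:]

mask_idx, unmask_idx = get_random_indices(batch_size=4, num_patches=256, num_mask=192)
print(mask_idx.shape, unmask_idx.shape)  # (4, 192) (4, 64)
```

For every row, the two index sets are disjoint and together cover all 256 patches.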
In [10]:
patch_encoder = PatchEncoder()

(
    unmasked_embeddings,
    masked_embeddings,
    unmasked_positions,
    mask_indices,
    unmask_indices,
) = patch_encoder(patches=patches)


new_patch, random_index = patch_encoder.generate_masked_image(patches, unmask_indices)

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
img = patch_layer.reconstruct_from_patch(new_patch)
plt.imshow(keras.utils.array_to_img(img))
plt.axis("off")
plt.title("Masked")
plt.subplot(1, 2, 2)
img = augmented_images[random_index]
plt.imshow(keras.utils.array_to_img(img))
plt.axis("off")
plt.title("Original")
plt.show()
In [11]:
def mlp(x, dropout_rate, hidden_units):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
In [12]:
def create_encoder(num_heads=ENC_NUM_HEADS, num_layers=ENC_LAYERS):
    inputs = layers.Input((None, ENC_PROJECTION_DIM))
    x = inputs

    for _ in range(num_layers):
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=ENC_PROJECTION_DIM, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, x])
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
        x3 = mlp(x3, hidden_units=ENC_TRANSFORMER_UNITS, dropout_rate=0.1)
        x = layers.Add()([x3, x2])

    outputs = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
    return keras.Model(inputs, outputs, name="pimae_encoder")
In [13]:
def create_decoder(
    num_layers=DEC_LAYERS, num_heads=DEC_NUM_HEADS, image_size=IMAGE_SIZE
):
    inputs = layers.Input((NUM_PATCHES, ENC_PROJECTION_DIM))
    x = layers.Dense(DEC_PROJECTION_DIM)(inputs)

    for _ in range(num_layers):
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=DEC_PROJECTION_DIM, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, x])
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
        x3 = mlp(x3, hidden_units=DEC_TRANSFORMER_UNITS, dropout_rate=0.1)
        x = layers.Add()([x3, x2])

    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)
    x = layers.Flatten()(x)
    pre_final = layers.Dense(units=image_size * image_size * 3, activation="sigmoid")(x)
    outputs = layers.Reshape((image_size, image_size, 3))(pre_final)

    return keras.Model(inputs, outputs, name="pimae_decoder")
In [14]:
class PIMaskedAutoencoder(keras.Model):
    def __init__(
        self,
        train_augmentation_model,
        test_augmentation_model,
        patch_layer,
        patch_encoder,
        encoder,
        decoder,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.train_augmentation_model = train_augmentation_model
        self.test_augmentation_model = test_augmentation_model
        self.patch_layer = patch_layer
        self.patch_encoder = patch_encoder
        self.encoder = encoder
        self.decoder = decoder

    def calculate_loss(self, images, test=False):
        if test:
            augmented_images = self.test_augmentation_model(images)
        else:
            augmented_images = self.train_augmentation_model(images)

        patches = self.patch_layer(augmented_images)
        (
            unmasked_embeddings,
            masked_embeddings,
            unmasked_positions,
            mask_indices,
            unmask_indices,
        ) = self.patch_encoder(patches)
        encoder_outputs = self.encoder(unmasked_embeddings)
        encoder_outputs = encoder_outputs + unmasked_positions
        decoder_inputs = tf.concat([encoder_outputs, masked_embeddings], axis=1)
        decoder_outputs = self.decoder(decoder_inputs)
        decoder_patches = self.patch_layer(decoder_outputs)
        loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
        loss_output = tf.gather(decoder_patches, mask_indices, axis=1, batch_dims=1)
        total_loss = self.compiled_loss(loss_patch, loss_output)
        return total_loss, loss_patch, loss_output

    def train_step(self, images):
        with tf.GradientTape() as tape:
            total_loss, loss_patch, loss_output = self.calculate_loss(images)
        train_vars = [
            self.train_augmentation_model.trainable_variables,
            self.patch_layer.trainable_variables,
            self.patch_encoder.trainable_variables,
            self.encoder.trainable_variables,
            self.decoder.trainable_variables,
        ]
        grads = tape.gradient(total_loss, train_vars)
        tv_list = []
        for (grad, var) in zip(grads, train_vars):
            for g, v in zip(grad, var):
                tv_list.append((g, v))
        self.optimizer.apply_gradients(tv_list)
        self.compiled_metrics.update_state(loss_patch, loss_output)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, images):
        total_loss, loss_patch, loss_output = self.calculate_loss(images, test=True)
        self.compiled_metrics.update_state(loss_patch, loss_output)
        return {m.name: m.result() for m in self.metrics}
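Because `train_step` gathers trainable variables from five sub-models, `tape.gradient` returns a list of per-model lists, and the nested zip flattens it into one list of (gradient, variable) pairs for `apply_gradients`. A pure-Python sketch of that pairing, with string stand-ins for the actual tensors:

```python
# Each sub-model contributes its own list of gradients and variables;
# the nested zip flattens them into one list of (gradient, variable) pairs.
grads = [['g1', 'g2'], [], ['g3']]       # stand-ins for per-model gradient lists
train_vars = [['v1', 'v2'], [], ['v3']]  # stand-ins for per-model variable lists

tv_list = []
for grad, var in zip(grads, train_vars):
    for g, v in zip(grad, var):
        tv_list.append((g, v))

print(tv_list)  # [('g1', 'v1'), ('g2', 'v2'), ('g3', 'v3')]
```

Sub-models without trainable variables (like the augmentation models here) contribute empty lists and simply drop out of the flattened result.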
In [15]:
train_augmentation_model = get_train_augmentation_model()
test_augmentation_model = get_test_augmentation_model()
patch_layer = Patches()
patch_encoder = PatchEncoder()
encoder = create_encoder()
decoder = create_decoder()

patch_encoder.alreadyMasked = False

pimae_model = PIMaskedAutoencoder(
    train_augmentation_model=train_augmentation_model,
    test_augmentation_model=test_augmentation_model,
    patch_layer=patch_layer,
    patch_encoder=patch_encoder,
    encoder=encoder,
    decoder=decoder,
)
In [16]:
test_images = next(iter(test_ds))

class PIMAECustomCallback(keras.callbacks.Callback):

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        test_augmented_images = self.model.test_augmentation_model(test_images)
        test_patches = self.model.patch_layer(test_augmented_images)
        (
            test_unmasked_embeddings,
            test_masked_embeddings,
            test_unmasked_positions,
            test_mask_indices,
            test_unmask_indices,
        ) = self.model.patch_encoder(test_patches)
        test_encoder_outputs = self.model.encoder(test_unmasked_embeddings)
        test_encoder_outputs = test_encoder_outputs + test_unmasked_positions
        test_decoder_inputs = tf.concat(
            [test_encoder_outputs, test_masked_embeddings], axis=1
        )
        test_decoder_outputs = self.model.decoder(test_decoder_inputs)
        test_masked_patch, idx = self.model.patch_encoder.generate_masked_image(
            test_patches, test_unmask_indices
        )
        print(f"\nIdx chosen: {idx}")
        original_image = test_augmented_images[idx]
        masked_image = self.model.patch_layer.reconstruct_from_patch(
            test_masked_patch
        )
        reconstructed_image = test_decoder_outputs[idx]

        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
        ax[0].imshow(original_image)
        ax[0].set_title("Original")

        ax[1].imshow(masked_image)
        ax[1].set_title("Masked")

        ax[2].imshow(reconstructed_image)
        ax[2].set_title("Reconstructed")

        plt.show()
        plt.close()

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print(' ')
        print("PIMAE Stop testing; got log keys: {}".format(keys))
        print(' ')

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print(' ')
        print("PIMAE Start predicting; got log keys: {}".format(keys))
        print(' ')

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print(' ')
        print("PIMAE Stop predicting; got log keys: {}".format(keys))
        print(' ')
In [17]:
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
    ):
        super().__init__()

        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        if self.total_steps < self.warmup_steps:
            raise ValueError("total_steps must be larger than or equal to warmup_steps.")

        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / float(self.total_steps - self.warmup_steps)
        )
        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)

        if self.warmup_steps > 0:
            if self.learning_rate_base < self.warmup_learning_rate:
                raise ValueError(
                    "learning_rate_base must be larger than or equal to "
                    "warmup_learning_rate."
                )
            slope = (
                self.learning_rate_base - self.warmup_learning_rate
            ) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )


total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=LEARNING_RATE,
    total_steps=total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

lrs = [scheduled_lrs(step) for step in range(total_steps)]
plt.plot(lrs)
plt.xlabel("Step", fontsize=14)
plt.ylabel("LR", fontsize=14)
plt.show()
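The schedule above warms up linearly from 0 to the base learning rate, then cosine-anneals down to zero. The same formula can be spot-checked without TensorFlow; this NumPy sketch uses illustrative step counts (1000 total, 150 warmup), not the notebook's actual values:

```python
import numpy as np

def warmup_cosine(step, base_lr=5e-3, total_steps=1000, warmup_lr=0.0, warmup_steps=150):
    # Cosine annealing from base_lr down to 0 after the warmup phase.
    cos_annealed = np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps))
    lr = 0.5 * base_lr * (1 + cos_annealed)
    if warmup_steps > 0:
        # Linear ramp from warmup_lr to base_lr over the warmup phase.
        slope = (base_lr - warmup_lr) / warmup_steps
        lr = np.where(step < warmup_steps, slope * step + warmup_lr, lr)
    return float(np.where(step > total_steps, 0.0, lr))

lrs = [warmup_cosine(s) for s in range(1001)]
# Starts at 0, peaks at base_lr right after warmup, decays back to ~0.
print(lrs[0], max(lrs), lrs[-1])
```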
In [ ]:
optimizer = tfa.optimizers.AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)
pimae_model.compile(
    optimizer=optimizer, loss=keras.losses.MeanSquaredError(), metrics=["mae"]
)
history = pimae_model.fit(
    train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=[PIMAECustomCallback()],
)
loss, mae = pimae_model.evaluate(test_ds)
print(f"Loss: {loss:.2f}")
print(f"MAE: {mae:.2f}")
Epoch 1/50
157/157 [==============================] - ETA: 0s - loss: 0.0679 - mae: 0.1373
Idx chosen: 92
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 96s 603ms/step - loss: 0.0679 - mae: 0.1373 - val_loss: 0.0489 - val_mae: 0.1048
Epoch 2/50
157/157 [==============================] - ETA: 0s - loss: 0.0522 - mae: 0.1153
Idx chosen: 14
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0522 - mae: 0.1153 - val_loss: 0.0463 - val_mae: 0.1010
Epoch 3/50
157/157 [==============================] - ETA: 0s - loss: 0.0454 - mae: 0.1051
Idx chosen: 106
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0454 - mae: 0.1051 - val_loss: 0.0435 - val_mae: 0.0974
Epoch 4/50
157/157 [==============================] - ETA: 0s - loss: 0.0383 - mae: 0.0998
Idx chosen: 71
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 608ms/step - loss: 0.0383 - mae: 0.0998 - val_loss: 0.0384 - val_mae: 0.0996
Epoch 5/50
157/157 [==============================] - ETA: 0s - loss: 0.0347 - mae: 0.0953
Idx chosen: 188
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0347 - mae: 0.0953 - val_loss: 0.0364 - val_mae: 0.0945
Epoch 6/50
157/157 [==============================] - ETA: 0s - loss: 0.0305 - mae: 0.0868
Idx chosen: 20
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 602ms/step - loss: 0.0305 - mae: 0.0868 - val_loss: 0.0339 - val_mae: 0.0886
Epoch 7/50
157/157 [==============================] - ETA: 0s - loss: 0.0270 - mae: 0.0791
Idx chosen: 102
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0270 - mae: 0.0791 - val_loss: 0.0318 - val_mae: 0.0806
Epoch 8/50
157/157 [==============================] - ETA: 0s - loss: 0.0243 - mae: 0.0732
Idx chosen: 121
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 601ms/step - loss: 0.0243 - mae: 0.0732 - val_loss: 0.0299 - val_mae: 0.0808
Epoch 9/50
157/157 [==============================] - ETA: 0s - loss: 0.0214 - mae: 0.0666
Idx chosen: 210
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0214 - mae: 0.0666 - val_loss: 0.0270 - val_mae: 0.0721
Epoch 10/50
157/157 [==============================] - ETA: 0s - loss: 0.0195 - mae: 0.0624
Idx chosen: 214
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0195 - mae: 0.0624 - val_loss: 0.0273 - val_mae: 0.0699
Epoch 11/50
157/157 [==============================] - ETA: 0s - loss: 0.0183 - mae: 0.0595
Idx chosen: 74
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0183 - mae: 0.0595 - val_loss: 0.0251 - val_mae: 0.0682
Epoch 12/50
157/157 [==============================] - ETA: 0s - loss: 0.0171 - mae: 0.0566
Idx chosen: 202
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0171 - mae: 0.0566 - val_loss: 0.0244 - val_mae: 0.0675
Epoch 13/50
157/157 [==============================] - ETA: 0s - loss: 0.0163 - mae: 0.0551
Idx chosen: 87
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0163 - mae: 0.0551 - val_loss: 0.0236 - val_mae: 0.0648
Epoch 14/50
157/157 [==============================] - ETA: 0s - loss: 0.0156 - mae: 0.0534
Idx chosen: 116
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0156 - mae: 0.0534 - val_loss: 0.0230 - val_mae: 0.0621
Epoch 15/50
157/157 [==============================] - ETA: 0s - loss: 0.0149 - mae: 0.0518
Idx chosen: 99
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0149 - mae: 0.0518 - val_loss: 0.0230 - val_mae: 0.0623
Epoch 16/50
157/157 [==============================] - ETA: 0s - loss: 0.0142 - mae: 0.0500
Idx chosen: 103
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 602ms/step - loss: 0.0142 - mae: 0.0500 - val_loss: 0.0221 - val_mae: 0.0611
Epoch 17/50
157/157 [==============================] - ETA: 0s - loss: 0.0139 - mae: 0.0492
Idx chosen: 151
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 601ms/step - loss: 0.0139 - mae: 0.0492 - val_loss: 0.0226 - val_mae: 0.0617
Epoch 18/50
157/157 [==============================] - ETA: 0s - loss: 0.0136 - mae: 0.0487
Idx chosen: 130
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0136 - mae: 0.0487 - val_loss: 0.0217 - val_mae: 0.0601
Epoch 19/50
157/157 [==============================] - ETA: 0s - loss: 0.0131 - mae: 0.0474
Idx chosen: 149
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0131 - mae: 0.0474 - val_loss: 0.0229 - val_mae: 0.0623
Epoch 20/50
157/157 [==============================] - ETA: 0s - loss: 0.0128 - mae: 0.0466
Idx chosen: 52
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 600ms/step - loss: 0.0128 - mae: 0.0466 - val_loss: 0.0225 - val_mae: 0.0614
Epoch 21/50
157/157 [==============================] - ETA: 0s - loss: 0.0125 - mae: 0.0459
Idx chosen: 1
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 602ms/step - loss: 0.0125 - mae: 0.0459 - val_loss: 0.0217 - val_mae: 0.0597
Epoch 22/50
157/157 [==============================] - ETA: 0s - loss: 0.0121 - mae: 0.0448
Idx chosen: 87
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0121 - mae: 0.0448 - val_loss: 0.0220 - val_mae: 0.0599
Epoch 23/50
157/157 [==============================] - ETA: 0s - loss: 0.0119 - mae: 0.0445
Idx chosen: 235
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0119 - mae: 0.0445 - val_loss: 0.0206 - val_mae: 0.0578
Epoch 24/50
157/157 [==============================] - ETA: 0s - loss: 0.0119 - mae: 0.0444
Idx chosen: 157
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0119 - mae: 0.0444 - val_loss: 0.0216 - val_mae: 0.0593
Epoch 25/50
157/157 [==============================] - ETA: 0s - loss: 0.0115 - mae: 0.0435
Idx chosen: 37
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 601ms/step - loss: 0.0115 - mae: 0.0435 - val_loss: 0.0210 - val_mae: 0.0581
Epoch 26/50
157/157 [==============================] - ETA: 0s - loss: 0.0114 - mae: 0.0430
Idx chosen: 129
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 607ms/step - loss: 0.0114 - mae: 0.0430 - val_loss: 0.0217 - val_mae: 0.0594
Epoch 27/50
157/157 [==============================] - ETA: 0s - loss: 0.0112 - mae: 0.0426
Idx chosen: 191
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 601ms/step - loss: 0.0112 - mae: 0.0426 - val_loss: 0.0208 - val_mae: 0.0576
Epoch 28/50
157/157 [==============================] - ETA: 0s - loss: 0.0107 - mae: 0.0414
Idx chosen: 187
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0107 - mae: 0.0414 - val_loss: 0.0214 - val_mae: 0.0577
Epoch 29/50
157/157 [==============================] - ETA: 0s - loss: 0.0107 - mae: 0.0412
Idx chosen: 20
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0107 - mae: 0.0412 - val_loss: 0.0198 - val_mae: 0.0549
Epoch 30/50
157/157 [==============================] - ETA: 0s - loss: 0.0105 - mae: 0.0407
Idx chosen: 160
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 604ms/step - loss: 0.0105 - mae: 0.0407 - val_loss: 0.0211 - val_mae: 0.0574
Epoch 31/50
157/157 [==============================] - ETA: 0s - loss: 0.0103 - mae: 0.0402
Idx chosen: 203
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0103 - mae: 0.0402 - val_loss: 0.0209 - val_mae: 0.0566
Epoch 32/50
157/157 [==============================] - ETA: 0s - loss: 0.0099 - mae: 0.0392
Idx chosen: 57
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0099 - mae: 0.0392 - val_loss: 0.0214 - val_mae: 0.0581
Epoch 33/50
157/157 [==============================] - ETA: 0s - loss: 0.0098 - mae: 0.0388
Idx chosen: 21
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 600ms/step - loss: 0.0098 - mae: 0.0388 - val_loss: 0.0193 - val_mae: 0.0540
Epoch 34/50
157/157 [==============================] - ETA: 0s - loss: 0.0097 - mae: 0.0385
Idx chosen: 252
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0097 - mae: 0.0385 - val_loss: 0.0202 - val_mae: 0.0553
Epoch 35/50
157/157 [==============================] - ETA: 0s - loss: 0.0096 - mae: 0.0382
Idx chosen: 235
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 602ms/step - loss: 0.0096 - mae: 0.0382 - val_loss: 0.0208 - val_mae: 0.0568
Epoch 36/50
157/157 [==============================] - ETA: 0s - loss: 0.0093 - mae: 0.0374
Idx chosen: 88
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 601ms/step - loss: 0.0093 - mae: 0.0374 - val_loss: 0.0194 - val_mae: 0.0541
Epoch 37/50
157/157 [==============================] - ETA: 0s - loss: 0.0092 - mae: 0.0370
Idx chosen: 48
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 94s 601ms/step - loss: 0.0092 - mae: 0.0370 - val_loss: 0.0186 - val_mae: 0.0523
Epoch 38/50
157/157 [==============================] - ETA: 0s - loss: 0.0090 - mae: 0.0366
Idx chosen: 218
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0090 - mae: 0.0366 - val_loss: 0.0189 - val_mae: 0.0533
Epoch 39/50
157/157 [==============================] - ETA: 0s - loss: 0.0089 - mae: 0.0362
Idx chosen: 58
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0089 - mae: 0.0362 - val_loss: 0.0191 - val_mae: 0.0530
Epoch 40/50
157/157 [==============================] - ETA: 0s - loss: 0.0087 - mae: 0.0357
Idx chosen: 254
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0087 - mae: 0.0357 - val_loss: 0.0186 - val_mae: 0.0522
Epoch 41/50
157/157 [==============================] - ETA: 0s - loss: 0.0086 - mae: 0.0353
Idx chosen: 169
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0086 - mae: 0.0353 - val_loss: 0.0184 - val_mae: 0.0517
Epoch 42/50
157/157 [==============================] - ETA: 0s - loss: 0.0085 - mae: 0.0350
Idx chosen: 255
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0085 - mae: 0.0350 - val_loss: 0.0184 - val_mae: 0.0516
Epoch 43/50
157/157 [==============================] - ETA: 0s - loss: 0.0083 - mae: 0.0347
Idx chosen: 219
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 607ms/step - loss: 0.0083 - mae: 0.0347 - val_loss: 0.0184 - val_mae: 0.0515
Epoch 44/50
157/157 [==============================] - ETA: 0s - loss: 0.0082 - mae: 0.0344
Idx chosen: 187
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0082 - mae: 0.0344 - val_loss: 0.0183 - val_mae: 0.0513
Epoch 45/50
157/157 [==============================] - ETA: 0s - loss: 0.0081 - mae: 0.0341
Idx chosen: 207
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 603ms/step - loss: 0.0081 - mae: 0.0341 - val_loss: 0.0182 - val_mae: 0.0512
Epoch 46/50
157/157 [==============================] - ETA: 0s - loss: 0.0081 - mae: 0.0339
Idx chosen: 14
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 605ms/step - loss: 0.0081 - mae: 0.0339 - val_loss: 0.0182 - val_mae: 0.0512
Epoch 47/50
157/157 [==============================] - ETA: 0s - loss: 0.0081 - mae: 0.0340
Idx chosen: 189
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 96s 613ms/step - loss: 0.0081 - mae: 0.0340 - val_loss: 0.0180 - val_mae: 0.0510
Epoch 48/50
157/157 [==============================] - ETA: 0s - loss: 0.0081 - mae: 0.0342
Idx chosen: 189
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 607ms/step - loss: 0.0081 - mae: 0.0342 - val_loss: 0.0181 - val_mae: 0.0515
Epoch 49/50
157/157 [==============================] - ETA: 0s - loss: 0.0081 - mae: 0.0344
Idx chosen: 174
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0081 - mae: 0.0344 - val_loss: 0.0182 - val_mae: 0.0521
Epoch 50/50
157/157 [==============================] - ETA: 0s - loss: 0.0081 - mae: 0.0352
Idx chosen: 189
 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
157/157 [==============================] - 95s 606ms/step - loss: 0.0081 - mae: 0.0352 - val_loss: 0.0182 - val_mae: 0.0530
40/40 [==============================] - 6s 151ms/step - loss: 0.0178 - mae: 0.0523
Loss: 0.02
MAE: 0.05
In [ ]:
loss, mae = pimae_model.evaluate(test_ds, callbacks=[PIMAECustomCallback()])
print(f"Loss: {loss:.2f}")
print(f"MAE: {mae:.2f}")
Idx chosen: 50
40/40 [==============================] - ETA: 0s - loss: 0.0178 - mae: 0.0525 
PIMAE Stop testing; got log keys: ['loss', 'mae']
 
40/40 [==============================] - 6s 151ms/step - loss: 0.0178 - mae: 0.0525
Loss: 0.02
MAE: 0.05
In [ ]:
 
In [ ]:
patch_encoder.alreadyMasked = True  # Switch the downstream flag to True.
In [ ]:
test_ds
Out[ ]:
<_PrefetchDataset element_spec=TensorSpec(shape=(None, 32, 32, 3), dtype=tf.uint8, name=None)>
In [ ]:
import cv2 as cv
import numpy as np
from matplotlib import pyplot as plt
In [ ]:
import numpy as np

def rescale_image(image):
    # Ensure image data is in floating-point format for accurate scaling
    image = image.astype(np.float32)

    # Find the minimum and maximum pixel values in the image
    min_value = np.min(image)
    max_value = np.max(image)

    # Guard against a constant image to avoid division by zero
    if max_value == min_value:
        return np.zeros_like(image, dtype=np.uint8)

    # Rescale the image to have pixel values ranging from 0 to 255
    rescaled_image = 255.0 * (image - min_value) / (max_value - min_value)

    # Convert the rescaled image back to integer format (0 to 255)
    rescaled_image = rescaled_image.astype(np.uint8)

    return rescaled_image
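As a quick standalone check of the min-max scaling above (the sample array here is purely illustrative):

```python
import numpy as np

# Minimal min-max rescale, mirroring rescale_image above
def rescale_to_uint8(image):
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    return (255.0 * (image - lo) / (hi - lo)).astype(np.uint8)

sample = np.array([[-1.0, 0.0], [0.5, 1.0]])
out = rescale_to_uint8(sample)
print(out.min(), out.max())  # 0 255
```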
In [ ]:
class PIMAEBlockCustomCallback(keras.callbacks.Callback):

    def on_test_begin(self, logs=None):

        print('Changed the PIMAE Callback after training')

        test_augmented_images = self.model.test_augmentation_model(test_images)

        test_patches = self.model.patch_layer(test_augmented_images)
        (
            test_unmasked_embeddings,
            test_masked_embeddings,
            test_unmasked_positions,
            test_mask_indices,
            test_unmask_indices,
        ) = self.model.patch_encoder(test_patches)

        test_encoder_outputs = self.model.encoder(test_unmasked_embeddings)
        test_encoder_outputs = test_encoder_outputs + test_unmasked_positions
        test_decoder_inputs = tf.concat(
            [test_encoder_outputs, test_masked_embeddings], axis=1
        )
        test_decoder_outputs = self.model.decoder(test_decoder_inputs)
        test_masked_patch, idx = self.model.patch_encoder.generate_masked_image(
            test_patches, test_unmask_indices
        )
        original_image = test_augmented_images[idx]
        masked_image = self.model.patch_layer.reconstruct_from_patch(
            test_masked_patch
        )
        reconstructed_image = test_decoder_outputs[idx]

        # Note: img, img_mask, and fig_fn are globals set by the evaluation loop below.
        fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
        ax[0].imshow(img_mask, cmap='gray')
        ax[0].set_title("Scan Pattern:")

        ax[1].imshow(img, cmap='gray')
        ax[1].set_title("Input Image:")

        ax[2].imshow(reconstructed_image, cmap='gray')
        ax[2].set_title("Reconstructed:")

        for a in ax:
            a.set_xticks([])
            a.set_yticks([])

        fig_path = os.path.join(os.getcwd(), f'PI-MAE-Scans/Results-From-Running-Model/75-Noise-Results/{fig_fn}')
        plt.savefig(fig_path, dpi=150)
        plt.show()
        plt.close()
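The callback depends on `patch_layer.reconstruct_from_patch` to fold flattened patches back into an image. That layer is defined earlier in the notebook and not shown here; as a rough sketch, the core rearrangement for non-overlapping patches can be done in NumPy (the 32x32 image size, 8x8 patch size, and row-major patch order are assumptions for illustration):

```python
import numpy as np

def patches_to_image(patches, image_size, patch_size, channels=1):
    # patches: (num_patches, patch_size * patch_size * channels), row-major order
    n = image_size // patch_size
    patches = patches.reshape(n, n, patch_size, patch_size, channels)
    # (row_block, col_block, ph, pw, c) -> (row_block, ph, col_block, pw, c) -> (H, W, C)
    return patches.transpose(0, 2, 1, 3, 4).reshape(image_size, image_size, channels)

# Round-trip check on a 32x32 single-channel image with 8x8 patches
img = np.arange(32 * 32).reshape(32, 32, 1)
p = img.reshape(4, 8, 4, 8, 1).transpose(0, 2, 1, 3, 4).reshape(16, -1)
assert np.array_equal(patches_to_image(p, 32, 8), img)
```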
In [ ]:
import imageio.v2 as imageio
import numpy as np
import cv2

import os

failed = []

MEMS_Masked_75 = os.path.join(os.getcwd(), 'PI-MAE-Scans/NOISE-MASK/75mask-noise')
Thresh_75 = os.path.join(os.getcwd(), 'PI-MAE-Scans/NOISE-MASK/75mask-noise')

for file in os.listdir(Thresh_75):

    if 'cv2' not in file:
        continue

    if file[-3:] != 'png':
        continue
    fig_fn = file
    img_fn = os.path.join(Thresh_75, file)
    print('img_fn', img_fn)
    img = cv.imread(img_fn, cv.IMREAD_GRAYSCALE)
    obj, mask_percent, _, idx, _ = file.split('-')
    mask_fn = f'{obj}-{mask_percent}-1-{idx}-plt-mask.png'
    mask_fn = os.path.join(MEMS_Masked_75, mask_fn)
    print('mask_fn', mask_fn)
    img_mask = cv.imread(mask_fn, cv.IMREAD_GRAYSCALE)

    # Apply the bitwise-inverted scan mask to the image
    arr = img*~img_mask

    if arr.shape[1] == 32:
        arr = rescale_image(arr)

    # Hand the scan pattern to the patch encoder for downstream masking
    patch_encoder.tmp_x = img_mask
    patch_encoder.setMasking = True

    x_ingas = np.array([arr])
    y_ingas = np.array([0])

    x_ingas = np.expand_dims(x_ingas, -1)
    y_ingas = np.expand_dims(y_ingas, -1)

    x_ingas = tf.image.grayscale_to_rgb(
        tf.convert_to_tensor(x_ingas),
        name=None
    )

    ingas_ds = tf.data.Dataset.from_tensor_slices(x_ingas)
    ingas_ds = ingas_ds.batch(BATCH_SIZE).prefetch(AUTO)

    x_ingas_ds = tf.data.Dataset.from_tensor_slices(x_ingas)
    x_ingas_ds = x_ingas_ds.batch(BATCH_SIZE).prefetch(AUTO)

    test_images = next(iter(ingas_ds))

    try:
        pimae_model.evaluate(test_ds, callbacks=[PIMAEBlockCustomCallback()])
    except Exception:
        failed.append(file)
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/Q-75-1-4-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/Q-75-1-4-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 151ms/step - loss: 0.0178 - mae: 0.0532
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/V-75-1-12-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/V-75-1-12-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 152ms/step - loss: 0.0180 - mae: 0.0520
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/N-75-1-6-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/N-75-1-6-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0170 - mae: 0.0513
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/2-75-1-1-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/2-75-1-1-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0206 - mae: 0.0561
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/O-75-1-3-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/O-75-1-3-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0214 - mae: 0.0605
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/D-75-1-10-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/D-75-1-10-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0231 - mae: 0.0626
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/X-75-1-16-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/X-75-1-16-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0154 - mae: 0.0478
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/3-75-1-0-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/3-75-1-0-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0209 - mae: 0.0586
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/6-75-1-11-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/6-75-1-11-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 152ms/step - loss: 0.0181 - mae: 0.0535
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/Z-75-1-8-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/Z-75-1-8-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0169 - mae: 0.0503
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/5-75-1-10-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/5-75-1-10-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 151ms/step - loss: 0.0166 - mae: 0.0503
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/L-75-1-15-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/L-75-1-15-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 151ms/step - loss: 0.0170 - mae: 0.0509
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/W-75-1-5-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/W-75-1-5-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0159 - mae: 0.0492
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/F-75-1-23-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/F-75-1-23-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0137 - mae: 0.0435
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/P-75-1-21-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/P-75-1-21-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0193 - mae: 0.0550
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/I-75-1-15-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/I-75-1-15-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 151ms/step - loss: 0.0152 - mae: 0.0465
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/C-75-1-0-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/C-75-1-0-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0190 - mae: 0.0541
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/1-75-1-10-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/1-75-1-10-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0170 - mae: 0.0507
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/H-75-1-4-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/H-75-1-4-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0169 - mae: 0.0510
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/A-75-1-18-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/A-75-1-18-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 151ms/step - loss: 0.0176 - mae: 0.0518
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/7-75-1-1-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/7-75-1-1-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 149ms/step - loss: 0.0147 - mae: 0.0461
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/B-75-1-2-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/B-75-1-2-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0209 - mae: 0.0586
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/M-75-1-20-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/M-75-1-20-plt-mask.png
Changed the PIMAE Callback after training
40/40 [==============================] - 6s 150ms/step - loss: 0.0144 - mae: 0.0455
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/G-75-1-29-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/G-75-1-29-plt-mask.png
40/40 [==============================] - 6s 151ms/step - loss: 0.0149 - mae: 0.0471
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/U-75-1-11-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/U-75-1-11-plt-mask.png
40/40 [==============================] - 6s 150ms/step - loss: 0.0174 - mae: 0.0524
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/K-75-1-25-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/K-75-1-25-plt-mask.png
40/40 [==============================] - 6s 150ms/step - loss: 0.0203 - mae: 0.0571
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/T-75-1-8-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/T-75-1-8-plt-mask.png
40/40 [==============================] - 6s 152ms/step - loss: 0.0188 - mae: 0.0536
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/Y-75-1-22-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/Y-75-1-22-plt-mask.png
40/40 [==============================] - 6s 151ms/step - loss: 0.0154 - mae: 0.0459
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/0-75-1-2-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/0-75-1-2-plt-mask.png
40/40 [==============================] - 6s 150ms/step - loss: 0.0177 - mae: 0.0525
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/9-75-1-2-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/9-75-1-2-plt-mask.png
40/40 [==============================] - 6s 149ms/step - loss: 0.0143 - mae: 0.0453
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/S-75-1-21-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/S-75-1-21-plt-mask.png
40/40 [==============================] - 6s 152ms/step - loss: 0.0149 - mae: 0.0472
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/8-75-1-11-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/8-75-1-11-plt-mask.png
40/40 [==============================] - 6s 150ms/step - loss: 0.0122 - mae: 0.0412
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/R-75-1-3-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/R-75-1-3-plt-mask.png
40/40 [==============================] - 6s 151ms/step - loss: 0.0164 - mae: 0.0501
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/E-75-1-4-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/E-75-1-4-plt-mask.png
40/40 [==============================] - 6s 151ms/step - loss: 0.0176 - mae: 0.0516
img_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/4-75-1-5-cv2.png
mask_fn /home/ec2-user/SageMaker/Dialate MNIST/MAE/MAE/onPredict/PI-MAE-Scans/NOISE-MASK/75mask-noise/4-75-1-5-plt-mask.png
40/40 [==============================] - 6s 149ms/step - loss: 0.0203 - mae: 0.0565
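The per-image evaluation figures scattered through the log above are easier to compare once aggregated. A minimal sketch that summarizes them (the `losses` and `maes` lists are simply the values transcribed from the log, in order; nothing is recomputed from the model):

```python
import statistics

# loss and mae values transcribed from the 17 evaluation runs logged above
losses = [0.0169, 0.0176, 0.0147, 0.0209, 0.0144, 0.0149, 0.0174,
          0.0203, 0.0188, 0.0154, 0.0177, 0.0143, 0.0149, 0.0122,
          0.0164, 0.0176, 0.0203]
maes = [0.0510, 0.0518, 0.0461, 0.0586, 0.0455, 0.0471, 0.0524,
        0.0571, 0.0536, 0.0459, 0.0525, 0.0453, 0.0472, 0.0412,
        0.0501, 0.0516, 0.0565]

# mean and sample standard deviation across the evaluated images
mean_loss = statistics.mean(losses)
mean_mae = statistics.mean(maes)
print(f"mean loss: {mean_loss:.4f} +/- {statistics.stdev(losses):.4f}")
print(f"mean mae:  {mean_mae:.4f} +/- {statistics.stdev(maes):.4f}")
```

This gives a mean loss of roughly 0.0167 and a mean MAE of roughly 0.0502 across the 75%-mask noise scans, which is a more useful single figure for comparing against other mask ratios than any individual run.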